import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

###########################################
# Link every pair of events that occur close enough
# in space and time within the same cell. Links can be
# non-unique: each event can be linked to more than one
# other. This is why the later MergeSets function is needed.

def CreateSets_Pairwise1(filepath, df_temp, dist_cut, t_cut):
    """Link events of the same cell that are close in both space and time.

    For every event, collect the index labels of all events in the same
    cell whose Euclidean distance (columns "x (um)" / "y (um)") is strictly
    below ``dist_cut`` and whose absolute time difference (column
    "Time (s)") is strictly below ``t_cut``.  An event normally matches
    itself, so a set with fewer than two members means "no partner" and is
    stored as the sentinel ``-100`` instead.

    Parameters
    ----------
    filepath :
        Unused; kept so existing callers keep working.
    df_temp : pandas.DataFrame
        Must contain the columns "Cell #", "x (um)", "y (um)" and
        "Time (s)".  It is not modified.
    dist_cut, t_cut :
        Exclusive upper bounds on distance and time difference.

    Returns
    -------
    cluster_cellinds : list
        Per-event entries (set of index labels, or ``-100``) for the LAST
        cell number in ``pd.unique`` order; ``[]`` if the frame is empty.
    df_temp2 : pandas.DataFrame
        Deep copy of ``df_temp`` with an object column "PairClusterSets"
        holding a set or ``-100`` for every row.

    Notes
    -----
    Cells with a single event and cell number 0 are skipped (all of their
    events get ``-100``).  NOTE(review): cell 0 is presumably background --
    confirm with the data producer.
    """
    df_temp2 = df_temp.copy(deep=True)

    cell_col = df_temp["Cell #"]
    x_all = df_temp["x (um)"].to_numpy()
    y_all = df_temp["y (um)"].to_numpy()
    t_all = df_temp["Time (s)"].to_numpy()
    labels = df_temp.index

    # Object array so each slot can hold either the int sentinel or a set.
    # Every row starts as "not in a cluster".
    pair_sets = np.empty(df_temp.shape[0], dtype=object)
    pair_sets[:] = -100

    # Defined up front so an empty frame returns [] instead of raising.
    cluster_cellinds = []

    for cn in pd.unique(cell_col):
        # Positional rows of this cell's events.
        rows = np.flatnonzero((cell_col == cn).to_numpy())
        n_events = rows.size

        # Default every event of this cell to the sentinel (the previous
        # implementation left uninitialised memory here for skipped cells).
        cluster_cellinds = [-100] * n_events

        if n_events <= 1 or cn == 0:
            continue

        x = x_all[rows]
        y = y_all[rows]
        t = t_all[rows]
        cell_labels = labels[rows]

        for i in range(n_events):
            # Compare event i against every event of the cell at once.
            dist = np.sqrt((x - x[i]) ** 2 + (y - y[i]) ** 2)
            close = (dist < dist_cut) & (np.abs(t - t[i]) < t_cut)

            # tolist() yields plain Python labels, which keeps the sets
            # readable if the frame is later written to a text file.
            members = set(cell_labels[close].tolist())

            if len(members) >= 2:
                cluster_cellinds[i] = members
                pair_sets[rows[i]] = members

    # Assign the whole column positionally; this avoids relying on implicit
    # dtype upcasting when sets are written into a numeric column.
    df_temp2["PairClusterSets"] = pair_sets

    return cluster_cellinds, df_temp2


####################################################
# Assign a unique cluster number to each cluster as an ID.

def AssignClusters(df_temp, cluster_start):
    """Give every cluster set a unique integer ID in column "ClusterNum_wrh".

    Rows are visited in index order.  A row whose "PairClusterSets" entry
    is not a ``set`` gets ``-100`` in both "ClusterNum_wrh" and
    "PairClusterSets".  A row whose set is not already fully covered by
    previously numbered sets starts a new cluster: every index label in
    the set receives the next cluster number.

    Parameters
    ----------
    df_temp : pandas.DataFrame
        Must contain "PairClusterSets" whose sets hold index labels of
        ``df_temp``.  It is not modified.
    cluster_start : int
        The first cluster created is numbered ``cluster_start + 1``.

    Returns
    -------
    df_temp2 : pandas.DataFrame
        Deep copy of ``df_temp`` with "ClusterNum_wrh" filled in.
    cluster_count : int
        One more than the last cluster number that was assigned
        (``cluster_start + 1`` if no cluster was found).
    """
    df_temp2 = df_temp.copy(deep=True)

    # Make sure the output column exists even when the frame has no rows.
    if "ClusterNum_wrh" not in df_temp2.columns:
        df_temp2["ClusterNum_wrh"] = float("nan")

    seen = set()
    cluster_count = cluster_start + 1

    # Iterate over the real index labels so frames without a default
    # 0..n-1 index are handled as well.
    for label in df_temp.index:
        members = df_temp.loc[label, "PairClusterSets"]

        if not isinstance(members, set):
            # No cluster set for this event: mark it with the sentinel.
            df_temp2.loc[label, "ClusterNum_wrh"] = -100
            df_temp2.loc[label, "PairClusterSets"] = -100
        elif not members.issubset(seen):
            # pandas >= 2.0 rejects a set as an indexer, so pass a list.
            df_temp2.loc[list(members), "ClusterNum_wrh"] = int(cluster_count)
            cluster_count += 1
            seen |= members

    return df_temp2, cluster_count


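#########################################################
# This function takes all the pairwise clusterings and
# merges any that have common elements (within the same cell).
# A single pass can leave some sets only partially merged, so
# run it at least twice and keep re-running while the returned
# new_merge flag is True. Once it returns False, every event is
# in a unique cluster (or in no cluster if appropriate).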
def MergeSets(df_temp):
    """Run one merge pass over the pairwise cluster sets.

    For every ordered pair of rows ``i < j`` that belong to the same
    "Cell #" and both hold a ``set`` in "PairClusterSets", overlapping but
    unequal sets are replaced (on both rows) by their union.  Row ``i``
    keeps accumulating while ``j`` advances, so later rows see the grown
    set.  Rows before ``i`` are not revisited, which is why one pass may
    not be enough: call again while the returned flag is True.

    Parameters
    ----------
    df_temp : pandas.DataFrame
        Must contain "PairClusterSets" and "Cell #".  It is not modified
        (earlier versions wrote the merged sets back into the caller's
        frame as a side effect).

    Returns
    -------
    df_temp2 : pandas.DataFrame
        Deep copy of ``df_temp`` with the merged "PairClusterSets".
    new_merge : bool
        True if at least one pair of sets was merged during this pass.
    """
    df_temp2 = df_temp.copy(deep=True)

    # Snapshot both columns into plain lists: access is positional (no
    # dependence on the index labels), cheap inside the double loop, and
    # nothing is ever written back into the caller's DataFrame.
    merged_sets = df_temp["PairClusterSets"].tolist()
    cell_nums = df_temp["Cell #"].tolist()
    n_rows = len(merged_sets)

    new_merge = False
    for i in range(n_rows):
        c1_set = merged_sets[i]
        if not isinstance(c1_set, set):
            continue  # this event has no cluster set
        c1_num = cell_nums[i]

        for j in range(i + 1, n_rows):
            c2_set = merged_sets[j]
            if not isinstance(c2_set, set):
                continue
            if c1_num != cell_nums[j]:
                continue  # different cells are never merged

            # Overlapping but not yet identical: merge the two sets.
            if c1_set != c2_set and not c1_set.isdisjoint(c2_set):
                c1_set = c1_set | c2_set
                merged_sets[i] = c1_set
                merged_sets[j] = c1_set
                new_merge = True

    if new_merge:
        # Fill an object array element by element so pandas stores each set
        # as a single cell value.  Without a merge the copy is already right.
        column = np.empty(n_rows, dtype=object)
        for pos, value in enumerate(merged_sets):
            column[pos] = value
        df_temp2["PairClusterSets"] = column

    return df_temp2, new_merge





###############################################
def Get_Time_Between_Events(df_temp):
    """Return the time gaps between successive events inside each cluster.

    Each distinct set in "PairClusterSets" is treated as one cluster: the
    "Time (s)" values of its members are sorted in ascending order and the
    differences between consecutive values are collected.  A set that is
    already fully covered by previously processed sets is skipped.
    Entries that are not ``set`` objects (for example the ``-100``
    sentinel) are ignored.

    Parameters
    ----------
    df_temp : pandas.DataFrame
        Must contain "PairClusterSets" (sets of index labels) and
        "Time (s)".

    Returns
    -------
    pandas.Series
        All gaps concatenated; the index holds the label of the later
        event of each pair.  Empty (float dtype) if there is no cluster.
    """
    times = df_temp["Time (s)"]

    seen = set()
    gaps = []
    for members in df_temp["PairClusterSets"]:
        if not isinstance(members, set) or members.issubset(seen):
            continue
        seen |= members

        # pandas >= 2.0 rejects a set as an indexer, so pass a list.
        cluster_times = times.loc[list(members)].sort_values(ascending=True)
        gaps.append(cluster_times.diff().dropna())

    if not gaps:
        return pd.Series([], dtype=float)

    # Series.append was removed in pandas 2.0; concat is the replacement.
    running_tdiff = pd.concat(gaps)
    running_tdiff.name = None  # the accumulated series was always unnamed
    return running_tdiff



###############################################
# This differs from the other Get_Time_Between_Events
# function in that it assumes the index sets were read
# in from a file and thus are a string that needs to
# be evaluated to build an actual set. The line that is 
# different is flagged below.
def Get_Time_Between_Events_SetStr(df_temp):
    """Return within-cluster time gaps when the sets are stored as text.

    Same computation as ``Get_Time_Between_Events``, but each
    "PairClusterSets" entry is expected to be the string form of a set
    (as obtained after a round trip through a text file) and is evaluated
    to rebuild the set.  Entries that do not evaluate to a ``set`` (for
    example ``"-100"``) are ignored; entries that are not strings are used
    as they are.

    Parameters
    ----------
    df_temp : pandas.DataFrame
        Must contain "PairClusterSets" and "Time (s)"; set members must be
        index labels of ``df_temp``.

    Returns
    -------
    pandas.Series
        All gaps concatenated; the index holds the label of the later
        event of each pair.  Empty (float dtype) if there is no cluster.
    """
    times = df_temp["Time (s)"]

    seen = set()
    gaps = []
    for raw in df_temp["PairClusterSets"]:
        # SECURITY NOTE: eval executes arbitrary code.  Only use this on
        # files you produced yourself; never on untrusted input.
        members = eval(raw) if isinstance(raw, str) else raw  ##### DIFFERENCE HERE #####
        if not isinstance(members, set) or members.issubset(seen):
            continue
        seen |= members

        # pandas >= 2.0 rejects a set as an indexer, so pass a list.
        cluster_times = times.loc[list(members)].sort_values(ascending=True)
        gaps.append(cluster_times.diff().dropna())

    if not gaps:
        return pd.Series([], dtype=float)

    # Series.append was removed in pandas 2.0; concat is the replacement.
    running_tdiff = pd.concat(gaps)
    running_tdiff.name = None  # the accumulated series was always unnamed
    return running_tdiff





###############################################
def Get_Time_Between_Events_NonCluster(df_temp):
    """Return the gaps between successive events that are in no cluster.

    Rows whose "ClusterNum_wrh" (cast to int) equals -100 are selected,
    their "Time (s)" values are sorted in ascending order, and the
    differences between consecutive values are returned with the leading
    NaN removed.
    """
    # Boolean mask of the events flagged with the "no cluster" sentinel.
    is_unclustered = df_temp["ClusterNum_wrh"].astype(int) == -100

    ordered_times = df_temp.loc[is_unclustered, "Time (s)"].sort_values(ascending=True)

    return ordered_times.diff().dropna()



##############################################
# For each cluster, this code calculates the cluster 
# size and the time between each pair of successive
# events. It then links the cluster number and a unique
# cell number to that cluster.

def cluster_quant(df_temp):
    """Summarise every cluster: size, inter-event gaps and total duration.

    Rows are scanned in order.  A row whose "ClusterNum_wrh" is not -100
    contributes its "PairClusterSets" entry (a set, or the string form of
    one) as a cluster, unless that set is already fully covered by
    clusters recorded earlier.

    Parameters
    ----------
    df_temp : pandas.DataFrame
        Must contain "Cell #", "Time (s)", "PairClusterSets" and
        "ClusterNum_wrh"; set members must be index labels of ``df_temp``.

    Returns
    -------
    cluster_count : int
        Number of clusters recorded.
    num_in_cluster : int
        Number of distinct events that belong to a recorded cluster.
    num_noncluster : int
        Number of rows minus ``num_in_cluster``.
    cluster_data : numpy.ndarray
        Object array of shape ``(cluster_count, 5)``.  The columns are:
        running cluster index (starting at 0), "Cell #" of the row that
        introduced the cluster, cluster size, list of gaps between
        time-sorted successive events, and last time minus first time.
    number_of_cells : int
        Number of distinct values in "Cell #".
    """
    number_of_cells = df_temp["Cell #"].unique().shape[0]

    times = df_temp["Time (s)"]
    raw_sets = df_temp["PairClusterSets"].tolist()
    cluster_nums = df_temp["ClusterNum_wrh"].to_numpy()
    cell_nums = df_temp["Cell #"].to_numpy()

    seen = set()
    rows = []
    for pos, raw in enumerate(raw_sets):
        if cluster_nums[pos] == -100:
            continue  # event is not part of any cluster

        # SECURITY NOTE: eval executes arbitrary code.  Only use this on
        # files you produced yourself; never on untrusted input.
        members = set(eval(raw) if isinstance(raw, str) else raw)
        if members.issubset(seen):
            continue  # this cluster has already been recorded

        # pandas >= 2.0 rejects a set as an indexer, so pass a list.
        cluster_times = times.loc[list(members)].sort_values(ascending=True)
        gap_list = cluster_times.diff().dropna().values.tolist()
        age = cluster_times.iloc[-1] - cluster_times.iloc[0]

        rows.append((len(rows), cell_nums[pos], len(members), gap_list, age))
        seen |= members

    # One field of each row is a list, so fill an object array cell by
    # cell; letting NumPy infer the shape of such ragged rows raises on
    # NumPy >= 1.24.
    cluster_count = len(rows)
    cluster_data = np.empty((cluster_count, 5), dtype=object)
    for r, row in enumerate(rows):
        for c, value in enumerate(row):
            cluster_data[r, c] = value

    num_in_cluster = len(seen)
    num_total = df_temp.shape[0]
    num_noncluster = num_total - num_in_cluster

    return cluster_count, num_in_cluster, num_noncluster, \
           cluster_data, number_of_cells
            



